{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval\n",
    "\n",
    "## Overview\n",
    "RAPTOR is an advanced information retrieval and question-answering system that combines hierarchical document summarization, embedding-based retrieval, and contextual answer generation. It aims to efficiently handle large document collections by creating a multi-level tree of summaries, allowing for both broad and detailed information retrieval.\n",
    "\n",
    "## Motivation\n",
    "Traditional retrieval systems often struggle with large document sets, either missing important details or getting overwhelmed by irrelevant information. RAPTOR addresses this by creating a hierarchical structure of the document collection, allowing it to navigate between high-level concepts and specific details as needed.\n",
    "\n",
    "## Key Components\n",
    "1. **Tree Building**: Creates a hierarchical structure of document summaries.\n",
    "2. **Embedding and Clustering**: Organizes documents and summaries based on semantic similarity.\n",
    "3. **Vectorstore**: Efficiently stores and retrieves document and summary embeddings.\n",
    "4. **Contextual Retriever**: Selects the most relevant information for a given query.\n",
    "5. **Answer Generation**: Produces coherent responses based on retrieved information.\n",
    "\n",
    "## Method Details\n",
    "\n",
    "### Tree Building\n",
    "1. Start with original documents at level 0.\n",
    "2. For each level:\n",
    "   - Embed the texts using a language model.\n",
    "   - Cluster the embeddings (e.g., using Gaussian Mixture Models).\n",
    "   - Generate summaries for each cluster.\n",
    "   - Use these summaries as the texts for the next level.\n",
    "3. Continue until reaching a single summary or a maximum level.\n",
    "\n",
    "### Embedding and Retrieval\n",
    "1. Embed all documents and summaries from all levels of the tree.\n",
    "2. Store these embeddings in a vectorstore (e.g., FAISS) for efficient similarity search.\n",
    "3. For a given query:\n",
    "   - Embed the query.\n",
    "   - Retrieve the most similar documents/summaries from the vectorstore.\n",
    "\n",
    "### Contextual Compression\n",
    "1. Take the retrieved documents/summaries.\n",
    "2. Use a language model to extract only the most relevant parts for the given query.\n",
    "\n",
    "### Answer Generation\n",
    "1. Combine the relevant parts into a context.\n",
    "2. Use a language model to generate an answer based on this context and the original query.\n",
    "\n",
    "## Benefits of this Approach\n",
    "1. **Scalability**: Can handle large document collections by working with summaries at different levels.\n",
    "2. **Flexibility**: Capable of providing both high-level overviews and specific details.\n",
    "3. **Context-Awareness**: Retrieves information from the most appropriate level of abstraction.\n",
    "4. **Efficiency**: Uses embeddings and vectorstore for fast retrieval.\n",
    "5. **Traceability**: Maintains links between summaries and original documents, allowing for source verification.\n",
    "\n",
    "## Conclusion\n",
    "RAPTOR combines hierarchical summarization with embedding-based retrieval and contextual answer generation to handle large document collections. Because the summaries from every level of the tree are indexed alongside the original documents, a query can be answered from whichever level of abstraction fits it best: high-level summaries for broad questions, original passages for specific details.\n",
    "\n",
    "While RAPTOR shows great promise, future work could focus on optimizing the tree-building process, improving summary quality, and enhancing the retrieval mechanism to better handle complex, multi-faceted queries. Additionally, integrating this approach with other AI technologies could lead to even more sophisticated information processing systems."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"text-align: center;\">\n",
    "\n",
    "<img src=\"../images/raptor.svg\" alt=\"RAPTOR\" style=\"width:100%; height:auto;\">\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports and Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from typing import List, Dict, Any\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from langchain.chains.llm import LLMChain\n",
    "from langchain.embeddings import OpenAIEmbeddings\n",
    "from langchain.vectorstores import FAISS\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain.retrievers.document_compressors import LLMChainExtractor\n",
    "from langchain.schema import AIMessage\n",
    "from langchain.docstore.document import Document\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path since we work with notebooks\n",
    "from helper_functions import *\n",
    "from evaluation.evalute_rag import *\n",
    "\n",
    "# Load environment variables from a .env file\n",
    "load_dotenv()\n",
    "\n",
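    "# Ensure the OpenAI API key is available (load_dotenv() exports it when it is defined in .env)\n",
    "if not os.getenv('OPENAI_API_KEY'):\n",
    "    raise EnvironmentError(\"OPENAI_API_KEY is not set; add it to your .env file or environment\")"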
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define logging, LLM, and embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up logging\n",
    "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
    "\n",
    "embeddings = OpenAIEmbeddings()\n",
    "llm = ChatOpenAI(model_name=\"gpt-4o-mini\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper Functions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_text(item):\n",
    "    \"\"\"Extract text content from either a string or an AIMessage object.\"\"\"\n",
    "    if isinstance(item, AIMessage):\n",
    "        return item.content\n",
    "    return item\n",
    "\n",
    "def embed_texts(texts: List[str]) -> List[List[float]]:\n",
    "    \"\"\"Embed texts using OpenAIEmbeddings.\"\"\"\n",
    "    logging.info(f\"Embedding {len(texts)} texts\")\n",
    "    return embeddings.embed_documents([extract_text(text) for text in texts])\n",
    "\n",
    "def perform_clustering(embeddings: np.ndarray, n_clusters: int = 10) -> np.ndarray:\n",
    "    \"\"\"Perform clustering on embeddings using Gaussian Mixture Model.\"\"\"\n",
    "    logging.info(f\"Performing clustering with {n_clusters} clusters\")\n",
    "    gm = GaussianMixture(n_components=n_clusters, random_state=42)\n",
    "    return gm.fit_predict(embeddings)\n",
    "\n",
    "def summarize_texts(texts: List[str]) -> str:\n",
    "    \"\"\"Summarize a list of texts using OpenAI.\"\"\"\n",
    "    logging.info(f\"Summarizing {len(texts)} texts\")\n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Summarize the following text concisely:\\n\\n{text}\"\n",
    "    )\n",
    "    chain = prompt | llm\n",
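    "    # Join the cluster's texts into one string so the prompt receives prose, not a Python list repr\n",
    "    input_data = {\"text\": \"\\n\\n\".join(extract_text(text) for text in texts)}\n",
    "    # Return the summary text rather than the AIMessage wrapper, matching the declared return type\n",
    "    return extract_text(chain.invoke(input_data))\n",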
    "\n",
    "def visualize_clusters(embeddings: np.ndarray, labels: np.ndarray, level: int):\n",
    "    \"\"\"Visualize clusters using PCA.\"\"\"\n",
    "    from sklearn.decomposition import PCA\n",
    "    pca = PCA(n_components=2)\n",
    "    reduced_embeddings = pca.fit_transform(embeddings)\n",
    "    \n",
    "    plt.figure(figsize=(10, 8))\n",
    "    scatter = plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], c=labels, cmap='viridis')\n",
    "    plt.colorbar(scatter)\n",
    "    plt.title(f'Cluster Visualization - Level {level}')\n",
    "    plt.xlabel('First Principal Component')\n",
    "    plt.ylabel('Second Principal Component')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RAPTOR Core Function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def build_raptor_tree(texts: List[str], max_levels: int = 3) -> Dict[int, pd.DataFrame]:\n",
    "    \"\"\"Build the RAPTOR tree structure with level metadata and parent-child relationships.\"\"\"\n",
    "    results = {}\n",
    "    current_texts = [extract_text(text) for text in texts]\n",
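    "    # Give every original document an id so that level-1 summaries can reference their children\n",
    "    current_metadata = [{\"level\": 0, \"origin\": \"original\", \"parent_id\": None, \"id\": f\"doc_{i}\"} for i in range(len(texts))]\n",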
    "    \n",
    "    for level in range(1, max_levels + 1):\n",
    "        logging.info(f\"Processing level {level}\")\n",
    "        \n",
    "        embeddings = embed_texts(current_texts)\n",
    "        # GaussianMixture needs at least one component, even when only one text is left\n",
    "        n_clusters = max(1, min(10, len(current_texts) // 2))\n",
    "        cluster_labels = perform_clustering(np.array(embeddings), n_clusters)\n",
    "        \n",
    "        df = pd.DataFrame({\n",
    "            'text': current_texts,\n",
    "            'embedding': embeddings,\n",
    "            'cluster': cluster_labels,\n",
    "            'metadata': current_metadata\n",
    "        })\n",
    "        \n",
    "        results[level-1] = df\n",
    "        \n",
    "        summaries = []\n",
    "        new_metadata = []\n",
    "        for cluster in df['cluster'].unique():\n",
    "            cluster_docs = df[df['cluster'] == cluster]\n",
    "            cluster_texts = cluster_docs['text'].tolist()\n",
    "            cluster_metadata = cluster_docs['metadata'].tolist()\n",
    "            summary = summarize_texts(cluster_texts)\n",
    "            summaries.append(summary)\n",
    "            new_metadata.append({\n",
    "                \"level\": level,\n",
    "                \"origin\": f\"summary_of_cluster_{cluster}_level_{level-1}\",\n",
    "                \"child_ids\": [meta.get('id') for meta in cluster_metadata],\n",
    "                \"id\": f\"summary_{level}_{cluster}\"\n",
    "            })\n",
    "        \n",
    "        current_texts = summaries\n",
    "        current_metadata = new_metadata\n",
    "        \n",
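    "        # Store the top-level summaries when the tree converges or the level limit is reached;\n",
    "        # otherwise the summaries produced at the final level would be dropped\n",
    "        if len(current_texts) <= 1 or level == max_levels:\n",
    "            results[level] = pd.DataFrame({\n",
    "                'text': current_texts,\n",
    "                'embedding': embed_texts(current_texts),\n",
    "                'cluster': [0] * len(current_texts),\n",
    "                'metadata': current_metadata\n",
    "            })\n",
    "            logging.info(f\"Stopping at level {level} with {len(current_texts)} summaries\")\n",
    "            break\n",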
    "    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectorstore Function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vectorstore(tree_results: Dict[int, pd.DataFrame]) -> FAISS:\n",
    "    \"\"\"Build a FAISS vectorstore from all texts in the RAPTOR tree.\"\"\"\n",
    "    all_texts = []\n",
    "    all_embeddings = []\n",
    "    all_metadatas = []\n",
    "    \n",
    "    for level, df in tree_results.items():\n",
    "        # Summaries may be AIMessage objects, so extract their text rather than their repr\n",
    "        all_texts.extend([str(extract_text(text)) for text in df['text'].tolist()])\n",
    "        all_embeddings.extend([embedding.tolist() if isinstance(embedding, np.ndarray) else embedding for embedding in df['embedding'].tolist()])\n",
    "        all_metadatas.extend(df['metadata'].tolist())\n",
    "    \n",
    "    logging.info(f\"Building vectorstore with {len(all_texts)} texts\")\n",
    "    \n",
    "    # Reuse the embeddings computed while building the tree instead of embedding every text again\n",
    "    return FAISS.from_embeddings(\n",
    "        text_embeddings=list(zip(all_texts, all_embeddings)),\n",
    "        embedding=embeddings,\n",
    "        metadatas=all_metadatas\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define tree traversal retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tree_traversal_retrieval(query: str, vectorstore: FAISS, k: int = 3) -> List[Document]:\n",
    "    \"\"\"Perform tree traversal retrieval: search the top level, then descend into the children of each hit.\"\"\"\n",
    "    query_embedding = embeddings.embed_query(query)\n",
    "    \n",
    "    def retrieve_level(level: int, allowed_ids: List[str] = None) -> List[Document]:\n",
    "        # Keep only documents at this level and, when ids are given, only the children of the previous hits\n",
    "        def level_filter(meta: Dict[str, Any]) -> bool:\n",
    "            if meta.get('level') != level:\n",
    "                return False\n",
    "            return not allowed_ids or meta.get('id') in allowed_ids\n",
    "        \n",
    "        # FAISS returns (document, score) pairs; the filter is applied to the top fetch_k candidates\n",
    "        scored_docs = vectorstore.similarity_search_with_score_by_vector(\n",
    "            query_embedding,\n",
    "            k=k,\n",
    "            filter=level_filter\n",
    "        )\n",
    "        docs = [doc for doc, _ in scored_docs]\n",
    "        \n",
    "        if not docs or level == 0:\n",
    "            return docs\n",
    "        \n",
    "        # Flatten the child ids of all hits, skipping missing ids\n",
    "        child_ids = [cid for doc in docs for cid in doc.metadata.get('child_ids', []) if cid is not None]\n",
    "        \n",
    "        return docs + retrieve_level(level - 1, child_ids)\n",
    "    \n",
    "    # InMemoryDocstore keeps its documents in the private `_dict` attribute\n",
    "    max_level = max(doc.metadata['level'] for doc in vectorstore.docstore._dict.values())\n",
    "    return retrieve_level(max_level)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Retriever\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_retriever(vectorstore: FAISS) -> ContextualCompressionRetriever:\n",
    "    \"\"\"Create a retriever with contextual compression.\"\"\"\n",
    "    logging.info(\"Creating contextual compression retriever\")\n",
    "    base_retriever = vectorstore.as_retriever()\n",
    "    \n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Given the following context and question, extract only the relevant information for answering the question:\\n\\n\"\n",
    "        \"Context: {context}\\n\"\n",
    "        \"Question: {question}\\n\\n\"\n",
    "        \"Relevant Information:\"\n",
    "    )\n",
    "    \n",
    "    extractor = LLMChainExtractor.from_llm(llm, prompt=prompt)\n",
    "    \n",
    "    return ContextualCompressionRetriever(\n",
    "        base_compressor=extractor,\n",
    "        base_retriever=base_retriever\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define hierarchical retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
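    "def hierarchical_retrieval(query: str, retriever: ContextualCompressionRetriever, max_level: int) -> List[Document]:\n",
    "    \"\"\"Perform hierarchical retrieval from the highest level down, restricting each level to the children of the previous hits.\"\"\"\n",
    "    all_retrieved_docs = []\n",
    "    base_retriever = retriever.base_retriever\n",
    "    original_search_kwargs = dict(base_retriever.search_kwargs)\n",
    "    allowed_ids = None  # None or empty means no restriction on document ids\n",
    "    \n",
    "    try:\n",
    "        for level in range(max_level, -1, -1):\n",
    "            # Bind the loop variables as defaults so each filter keeps its own level and id set\n",
    "            def level_filter(meta, level=level, allowed_ids=allowed_ids):\n",
    "                if meta.get('level') != level:\n",
    "                    return False\n",
    "                return not allowed_ids or meta.get('id') in allowed_ids\n",
    "            \n",
    "            # The metadata filter must go through the vectorstore's search_kwargs;\n",
    "            # appending it to the query text would only distort the query embedding\n",
    "            base_retriever.search_kwargs = {**original_search_kwargs, 'filter': level_filter}\n",
    "            level_docs = retriever.get_relevant_documents(query)\n",
    "            all_retrieved_docs.extend(level_docs)\n",
    "            \n",
    "            # Restrict the next level down to the children of the documents just retrieved\n",
    "            if level_docs:\n",
    "                child_ids = [doc.metadata.get('child_ids', []) for doc in level_docs]\n",
    "                allowed_ids = [item for sublist in child_ids for item in sublist if item is not None]\n",
    "    finally:\n",
    "        # Restore the retriever's original search settings\n",
    "        base_retriever.search_kwargs = original_search_kwargs\n",
    "    \n",
    "    return all_retrieved_docs"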
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RAPTOR Query Process (Online Process)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "def raptor_query(query: str, retriever: ContextualCompressionRetriever, max_level: int) -> Dict[str, Any]:\n",
    "    \"\"\"Process a query using the RAPTOR system with hierarchical retrieval.\"\"\"\n",
    "    logging.info(f\"Processing query: {query}\")\n",
    "    \n",
    "    relevant_docs = hierarchical_retrieval(query, retriever, max_level)\n",
    "    \n",
    "    doc_details = []\n",
    "    for i, doc in enumerate(relevant_docs, 1):\n",
    "        doc_details.append({\n",
    "            \"index\": i,\n",
    "            \"content\": doc.page_content,\n",
    "            \"metadata\": doc.metadata,\n",
    "            \"level\": doc.metadata.get('level', 'Unknown'),\n",
    "            \"similarity_score\": doc.metadata.get('score', 'N/A')\n",
    "        })\n",
    "    \n",
    "    context = \"\\n\\n\".join([doc.page_content for doc in relevant_docs])\n",
    "    \n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Given the following context, please answer the question:\\n\\n\"\n",
    "        \"Context: {context}\\n\\n\"\n",
    "        \"Question: {question}\\n\\n\"\n",
    "        \"Answer:\"\n",
    "    )\n",
    "    chain = LLMChain(llm=llm, prompt=prompt)\n",
    "    answer = chain.run(context=context, question=query)\n",
    "    \n",
    "    logging.info(\"Query processing completed\")\n",
    "    \n",
    "    result = {\n",
    "        \"query\": query,\n",
    "        \"retrieved_documents\": doc_details,\n",
    "        \"num_docs_retrieved\": len(relevant_docs),\n",
    "        \"context_used\": context,\n",
    "        \"answer\": answer,\n",
    "        \"model_used\": llm.model_name,\n",
    "    }\n",
    "    \n",
    "    return result\n",
    "\n",
    "\n",
    "def print_query_details(result: Dict[str, Any]):\n",
    "    \"\"\"Print detailed information about the query process, including tree level metadata.\"\"\"\n",
    "    print(f\"Query: {result['query']}\")\n",
    "    print(f\"\\nNumber of documents retrieved: {result['num_docs_retrieved']}\")\n",
    "    print(f\"\\nRetrieved Documents:\")\n",
    "    for doc in result['retrieved_documents']:\n",
    "        print(f\"  Document {doc['index']}:\")\n",
    "        print(f\"    Content: {doc['content'][:100]}...\")  # Show first 100 characters\n",
    "        print(f\"    Similarity Score: {doc['similarity_score']}\")\n",
    "        print(f\"    Tree Level: {doc['metadata'].get('level', 'Unknown')}\")\n",
    "        print(f\"    Origin: {doc['metadata'].get('origin', 'Unknown')}\")\n",
    "        if 'child_ids' in doc['metadata']:\n",
    "            print(f\"    Number of Child Documents: {len(doc['metadata']['child_ids'])}\")\n",
    "        print()\n",
    "    \n",
    "    print(f\"\\nContext used for answer generation:\")\n",
    "    print(result['context_used'])\n",
    "    \n",
    "    print(f\"\\nGenerated Answer:\")\n",
    "    print(result['answer'])\n",
    "    \n",
    "    print(f\"\\nModel Used: {result['model_used']}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example Usage and Visualization\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the data path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../data/Understanding_Climate_Change.pdf\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Process texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
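    "# Imported explicitly so this cell does not depend on the star import from helper_functions\n",
    "from langchain.document_loaders import PyPDFLoader\n",
    "\n",
    "loader = PyPDFLoader(path)\n",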
    "documents = loader.load()\n",
    "texts = [doc.page_content for doc in documents]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create RAPTOR components instances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the RAPTOR tree\n",
    "tree_results = build_raptor_tree(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build vectorstore\n",
    "vectorstore = build_vectorstore(tree_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create retriever\n",
    "retriever = create_retriever(vectorstore)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run a query and inspect the retrieved sources and the answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pipeline\n",
    "max_level = max(tree_results.keys())  # Highest level actually present in the tree\n",
    "query = \"What is the greenhouse effect?\"\n",
    "result = raptor_query(query, retriever, max_level)\n",
    "print_query_details(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
